U-Net

Goals

In this notebook, you're going to implement a U-Net for a biomedical imaging segmentation task. Specifically, you're going to be labeling neurons, so one might call this a neural neural network! ;)

Note that this is not a GAN, generative model, or unsupervised learning task. This is a supervised learning task, so there's only one correct answer (like a classifier!). You will see how this component underlies the Generator component of Pix2Pix in the next notebook this week.

Learning Objectives

  1. Implement your own U-Net.
  2. Observe your U-Net's performance on a challenging segmentation task.

Getting Started

You will start by importing libraries, defining a visualization function, and getting the neural dataset that you will be using.

Dataset

For this notebook, you will be using a dataset of electron microscopy images and segmentation data. The information about the dataset you'll be using can be found here!

Arganda-Carreras et al. "Crowdsourcing the creation of image segmentation algorithms for connectomics". Front. Neuroanat. 2015. https://www.frontiersin.org/articles/10.3389/fnana.2015.00142/full

dataset example

In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0)

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots and prints the images in a uniform grid.
    '''
    # image_shifted = (image_tensor + 1) / 2
    image_shifted = image_tensor
    image_unflat = image_shifted.detach().cpu().view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=4)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

U-Net Architecture

Now you can build your U-Net from its components. The figure below is from the paper, U-Net: Convolutional Networks for Biomedical Image Segmentation, by Ronneberger et al. 2015. It shows the U-Net architecture and how it contracts and then expands.

Figure 1 from the paper, U-Net: Convolutional Networks for Biomedical Image Segmentation

In other words, images are first fed through many convolutional layers that reduce height and width while increasing the number of channels, which the authors refer to as the "contracting path." For example, a 2 x 2 convolution with a stride of 2 that doubles the channels will take a 1 x 28 x 28 (channels, height, width) grayscale image to a 2 x 14 x 14 representation. The "expanding path" does the opposite, gradually growing the image back while using fewer and fewer channels.
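As a quick sanity check on that shape arithmetic (a minimal sketch; the layer sizes here are illustrative, not part of the assignment):

```python
import torch
from torch import nn

# A single 2 x 2 convolution with stride 2 that doubles the channels:
# it halves height and width, so 1 x 28 x 28 becomes 2 x 14 x 14.
conv = nn.Conv2d(1, 2, kernel_size=2, stride=2)
image = torch.randn(1, 1, 28, 28)  # (batch, channels, height, width)
print(conv(image).shape)  # torch.Size([1, 2, 14, 14])
```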

Contracting Path

You will first implement the contracting blocks for the contracting path. This path is the encoder section of the U-Net, which consists of several downsampling steps. The authors describe it in the following paragraph from the paper (Ronneberger, 2015):

The contracting path follows the typical architecture of a convolutional network. It consists of the repeated application of two 3 x 3 convolutions (unpadded convolutions), each followed by a rectified linear unit (ReLU) and a 2 x 2 max pooling operation with stride 2 for downsampling. At each downsampling step we double the number of feature channels.

Optional hints for ContractingBlock

  1. Both convolutions should use 3 x 3 kernels.
  2. The max pool should use a 2 x 2 kernel with a stride of 2.
In [2]:
m = nn.Conv2d(16, 32, 3, stride=1)
In [3]:
input = torch.randn(20, 16, 50, 100)
In [4]:
output = m(input)
output.shape
Out[4]:
torch.Size([20, 32, 48, 98])
In [5]:
# UNQ_C1 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED CLASS: ContractingBlock
class ContractingBlock(nn.Module):
    '''
    ContractingBlock Class
    Performs two convolutions followed by a max pool operation.
    Values:
        input_channels: the number of channels to expect from a given input
    '''
    def __init__(self, input_channels):
        super(ContractingBlock, self).__init__()
        # You want to double the number of channels in the first convolution
        # and keep the same number of channels in the second.
        #### START CODE HERE ####
        self.conv1 = nn.Conv2d(input_channels, 2*input_channels, kernel_size=3)
        self.conv2 = nn.Conv2d(2*input_channels, 2*input_channels, kernel_size=3)
        self.activation = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        #### END CODE HERE ####

    def forward(self, x):
        '''
        Function for completing a forward pass of ContractingBlock: 
        Given an image tensor, completes a contracting block and returns the transformed tensor.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        x = self.conv1(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.activation(x)
        x = self.maxpool(x)
        return x
    
    # Required for grading
    def get_self(self):
        return self
In [6]:
#UNIT TEST
def test_contracting_block(test_samples=100, test_channels=10, test_size=50):
    test_block = ContractingBlock(test_channels)
    test_in = torch.randn(test_samples, test_channels, test_size, test_size)
    test_out_conv1 = test_block.conv1(test_in)
    # Make sure that the first convolution has the right shape
    assert tuple(test_out_conv1.shape) == (test_samples, test_channels * 2, test_size - 2, test_size - 2)
    # Make sure that the right activation is used
    assert torch.all(test_block.activation(test_out_conv1) >= 0)
    assert torch.max(test_block.activation(test_out_conv1)) >= 1
    test_out_conv2 = test_block.conv2(test_out_conv1)
    # Make sure that the second convolution has the right shape
    assert tuple(test_out_conv2.shape) == (test_samples, test_channels * 2, test_size - 4, test_size - 4)
    test_out = test_block(test_in)
    # Make sure that the pooling has the right shape
    assert tuple(test_out.shape) == (test_samples, test_channels * 2, test_size // 2 - 2, test_size // 2 - 2)

test_contracting_block()
test_contracting_block(10, 9, 8)
print("Success!")
Success!

Expanding Path

Next, you will implement the expanding blocks for the expanding path. This is the decoder section of the U-Net, which has several upsampling steps. In order to do this, you'll also need to write a crop function, so you can crop the image from the contracting path and concatenate it to the current image on the expanding path—this forms a skip connection. Again, the details are from the paper (Ronneberger, 2015):

Every step in the expanding path consists of an upsampling of the feature map followed by a 2 x 2 convolution (“up-convolution”) that halves the number of feature channels, a concatenation with the correspondingly cropped feature map from the contracting path, and two 3 x 3 convolutions, each followed by a ReLU. The cropping is necessary due to the loss of border pixels in every convolution.

Fun fact: later models based on this architecture often use padding in the convolutions to prevent the size of the image from changing outside of the upsampling / downsampling steps!
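For instance (a small sketch, not part of the graded code), adding padding=1 to a 3 x 3 convolution keeps the spatial size fixed, whereas the unpadded version used here loses one border pixel on each side:

```python
import torch
from torch import nn

x = torch.randn(1, 8, 64, 64)
unpadded = nn.Conv2d(8, 8, kernel_size=3)           # as in the original U-Net
padded = nn.Conv2d(8, 8, kernel_size=3, padding=1)  # "same"-size variant

print(unpadded(x).shape)  # torch.Size([1, 8, 62, 62])
print(padded(x).shape)    # torch.Size([1, 8, 64, 64])
```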

Optional hint for ExpandingBlock

  1. The concatenation means the number of channels goes back to being input_channels, so you need to halve it again for the next convolution.
In [7]:
# UNQ_C2 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED FUNCTION: crop
def crop(image, new_shape):
    '''
    Function for cropping an image tensor: Given an image tensor and the new shape,
    crops to the center pixels.
    Parameters:
        image: image tensor of shape (batch size, channels, height, width)
        new_shape: a torch.Size object with the shape you want x to have
    '''
    # There are many ways to implement this crop function, but it's what allows
    # the skip connection to function as intended with two differently sized images!
    #### START CODE HERE ####
    h = image.shape[2]
    w = image.shape[3]
    th = new_shape[2]
    tw = new_shape[3]
    i = int(round((h - th) / 2.))
    j = int(round((w - tw) / 2.))
    cropped_image = image[:, :, i:(i+th), j:(j+tw)]
    #### END CODE HERE ####
    return cropped_image
In [8]:
#UNIT TEST
def test_expanding_block_crop(test_samples=100, test_channels=10, test_size=100):
    # Make sure that the crop function is the right shape
    skip_con_x = torch.randn(test_samples, test_channels, test_size + 6, test_size + 6)
    x = torch.randn(test_samples, test_channels, test_size, test_size)
    cropped = crop(skip_con_x, x.shape)
    assert tuple(cropped.shape) == (test_samples, test_channels, test_size, test_size)

    # Make sure that the crop function takes the right area
    test_meshgrid = torch.meshgrid([torch.arange(0, test_size), torch.arange(0, test_size)])
    test_meshgrid = test_meshgrid[0] + test_meshgrid[1]
    test_meshgrid = test_meshgrid[None, None, :, :].float()
    cropped = crop(test_meshgrid, torch.Size([1, 1, test_size // 2, test_size // 2]))
    assert cropped.max() == (test_size - 1) * 2 - test_size // 2
    assert cropped.min() == test_size // 2
    assert cropped.mean() == test_size - 1

    test_meshgrid = torch.meshgrid([torch.arange(0, test_size), torch.arange(0, test_size)])
    test_meshgrid = test_meshgrid[0] + test_meshgrid[1]
    crop_size = 5
    test_meshgrid = test_meshgrid[None, None, :, :].float()
    cropped = crop(test_meshgrid, torch.Size([1, 1, crop_size, crop_size]))
    assert cropped.max() <= (test_size + crop_size - 1) and cropped.max() >= test_size - 1
    assert cropped.min() >= (test_size - crop_size - 1) and cropped.min() <= test_size - 1
    assert abs(cropped.mean() - test_size) <= 2

test_expanding_block_crop()
print("Success!")
Success!
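To see the center-crop arithmetic concretely, here is a standalone sketch that mirrors the crop function above without calling it:

```python
import torch

# An 8 x 8 map whose entry at (row, col) is 8 * row + col.
image = torch.arange(64).view(1, 1, 8, 8).float()

# Cropping to 4 x 4 keeps the central window: offset = round((8 - 4) / 2) = 2.
i = round((8 - 4) / 2)
cropped = image[:, :, i:i + 4, i:i + 4]
print(cropped[0, 0, 0].tolist())  # [18.0, 19.0, 20.0, 21.0]
```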
In [9]:
# UNQ_C3 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED CLASS: ExpandingBlock
class ExpandingBlock(nn.Module):
    '''
    ExpandingBlock Class
    Performs an upsampling, a convolution, a concatenation of its two inputs,
    followed by two more convolutions.
    Values:
        input_channels: the number of channels to expect from a given input
    '''
    def __init__(self, input_channels):
        super(ExpandingBlock, self).__init__()
        # "Every step in the expanding path consists of an upsampling of the feature map"
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        # "followed by a 2x2 convolution that halves the number of feature channels"
        # "a concatenation with the correspondingly cropped feature map from the contracting path"
        # "and two 3x3 convolutions"
        #### START CODE HERE ####
        self.conv1 = nn.Conv2d(input_channels, input_channels//2, kernel_size=2)
        self.conv2 = nn.Conv2d(input_channels, input_channels//2, kernel_size=3)
        self.conv3 = nn.Conv2d(input_channels//2, input_channels//2, kernel_size=3)
        #### END CODE HERE ####
        self.activation = nn.ReLU() # "each followed by a ReLU"
 
    def forward(self, x, skip_con_x):
        '''
        Function for completing a forward pass of ExpandingBlock: 
        Given an image tensor, completes an expanding block and returns the transformed tensor.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
            skip_con_x: the image tensor from the contracting path (from the opposing block of x)
                    for the skip connection
        '''
        x = self.upsample(x)
        x = self.conv1(x)
        skip_con_x = crop(skip_con_x, x.shape)
        x = torch.cat([x, skip_con_x], axis=1)
        x = self.conv2(x)
        x = self.activation(x)
        x = self.conv3(x)
        x = self.activation(x)
        return x
    
    # Required for grading
    def get_self(self):
        return self
In [10]:
#UNIT TEST
def test_expanding_block(test_samples=100, test_channels=10, test_size=50):
    test_block = ExpandingBlock(test_channels)
    skip_con_x = torch.randn(test_samples, test_channels // 2, test_size * 2 + 6, test_size * 2 + 6)
    x = torch.randn(test_samples, test_channels, test_size, test_size)
    x = test_block.upsample(x)
    x = test_block.conv1(x)
    # Make sure that the first convolution produces the right shape
    assert tuple(x.shape) == (test_samples, test_channels // 2,  test_size * 2 - 1, test_size * 2 - 1)
    original_x = crop(skip_con_x, x.shape)
    x = torch.cat([x, original_x], axis=1)
    x = test_block.conv2(x)
    # Make sure that the second convolution produces the right shape
    assert tuple(x.shape) == (test_samples, test_channels // 2,  test_size * 2 - 3, test_size * 2 - 3)
    x = test_block.conv3(x)
    # Make sure that the final convolution produces the right shape
    assert tuple(x.shape) == (test_samples, test_channels // 2,  test_size * 2 - 5, test_size * 2 - 5)
    x = test_block.activation(x)

test_expanding_block()
print("Success!")
Success!

Final Layer

Now you will write the final feature mapping block, which takes in a tensor with arbitrarily many channels and produces a tensor with the same height and width but the correct number of output channels. From the paper (Ronneberger, 2015):

At the final layer a 1x1 convolution is used to map each 64-component feature vector to the desired number of classes. In total the network has 23 convolutional layers.
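A 1 x 1 convolution is just a per-pixel linear map across channels, so it changes the channel count without touching height or width. A minimal sketch (the 64 channels and 388 x 388 resolution match the paper's final layer; the 2 output classes are illustrative):

```python
import torch
from torch import nn

final = nn.Conv2d(64, 2, kernel_size=1)  # 64 feature channels -> 2 classes
features = torch.randn(1, 64, 388, 388)
print(final(features).shape)  # torch.Size([1, 2, 388, 388])
```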

In [11]:
# UNQ_C4 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED CLASS: FeatureMapBlock
class FeatureMapBlock(nn.Module):
    '''
    FeatureMapBlock Class
    The final layer of a UNet - 
    maps each pixel to a pixel with the correct number of output dimensions
    using a 1x1 convolution.
    Values:
        input_channels: the number of channels to expect from a given input
    '''
    def __init__(self, input_channels, output_channels):
        super(FeatureMapBlock, self).__init__()
        # "Every step in the expanding path consists of an upsampling of the feature map"
        #### START CODE HERE ####
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=1)
        #### END CODE HERE ####

    def forward(self, x):
        '''
        Function for completing a forward pass of FeatureMapBlock: 
        Given an image tensor, returns it mapped to the desired number of channels.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        x = self.conv(x)
        return x
In [12]:
# UNIT TEST
assert tuple(FeatureMapBlock(10, 60)(torch.randn(1, 10, 10, 10)).shape) == (1, 60, 10, 10)
print("Success!")
Success!

U-Net

Now you can put it all together! Here, you'll write a UNet class which will combine a series of the three kinds of blocks you've implemented.

In [13]:
# UNQ_C5 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED CLASS: UNet
class UNet(nn.Module):
    '''
    UNet Class
    A series of 4 contracting blocks followed by 4 expanding blocks to 
    transform an input image into the corresponding paired image, with an upfeature
    layer at the start and a downfeature layer at the end
    Values:
        input_channels: the number of channels to expect from a given input
        output_channels: the number of channels to expect for a given output
    '''
    def __init__(self, input_channels, output_channels, hidden_channels=64):
        super(UNet, self).__init__()
        # "Every step in the expanding path consists of an upsampling of the feature map"
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels)
        self.contract2 = ContractingBlock(hidden_channels * 2)
        self.contract3 = ContractingBlock(hidden_channels * 4)
        self.contract4 = ContractingBlock(hidden_channels * 8)
        self.expand1 = ExpandingBlock(hidden_channels * 16)
        self.expand2 = ExpandingBlock(hidden_channels * 8)
        self.expand3 = ExpandingBlock(hidden_channels * 4)
        self.expand4 = ExpandingBlock(hidden_channels * 2)
        self.downfeature = FeatureMapBlock(hidden_channels, output_channels)

    def forward(self, x):
        '''
        Function for completing a forward pass of UNet: 
        Given an image tensor, passes it through U-Net and returns the output.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        # Keep in mind that the expand function takes two inputs, 
        # both with the same number of channels. 
        #### START CODE HERE ####
        x0 = self.upfeature(x)
        x1 = self.contract1(x0)
        x2 = self.contract2(x1)
        x3 = self.contract3(x2)
        x4 = self.contract4(x3)
        x5 = self.expand1(x4, x3)
        x6 = self.expand2(x5, x2)
        x7 = self.expand3(x6, x1)
        x8 = self.expand4(x7, x0)
        xn = self.downfeature(x8)
        #### END CODE HERE ####
        return xn
In [14]:
#UNIT TEST
test_unet = UNet(1, 3)
assert tuple(test_unet(torch.randn(1, 1, 256, 256)).shape) == (1, 3, 117, 117)
print("Success!")
Success!
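The 256-to-117 shape in the test above comes straight from the unpadded convolutions. A sketch of the per-level arithmetic (assuming the block definitions above: two 3 x 3 convs and a 2 x 2 max pool per contracting block; a 2x upsample, one 2 x 2 conv, and two 3 x 3 convs per expanding block):

```python
def contract(size):
    # two unpadded 3 x 3 convs (-2 each), then a 2 x 2 max pool (floor halving)
    return (size - 4) // 2

def expand(size):
    # 2x upsample, a 2 x 2 conv (-1), then two unpadded 3 x 3 convs (-2 each)
    return size * 2 - 5

size = 256  # upfeature is a 1 x 1 conv, so it leaves the size unchanged
for _ in range(4):
    size = contract(size)  # 126, 61, 28, 12
for _ in range(4):
    size = expand(size)    # 19, 33, 61, 117
print(size)  # 117
```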

Training

Finally, you will put this into action! Remember that these are your parameters:

  • criterion: the loss function
  • n_epochs: the number of times you iterate through the entire dataset when training
  • input_dim: the number of channels of the input image
  • label_dim: the number of channels of the output image
  • display_step: how often to display/visualize the images
  • batch_size: the number of images per forward/backward pass
  • lr: the learning rate
  • initial_shape: the size of the input image (in pixels)
  • target_shape: the size of the output image (in pixels)
  • device: the device type

This should take only a few minutes to train!

In [15]:
import torch.nn.functional as F
criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
input_dim = 1
label_dim = 1
display_step = 20
batch_size = 4
lr = 0.0002
initial_shape = 512
target_shape = 373
device = 'cuda'
In [16]:
from skimage import io
import numpy as np
volumes = torch.Tensor(io.imread('train-volume.tif'))[:, None, :, :] / 255
labels = torch.Tensor(io.imread('train-labels.tif', plugin="tifffile"))[:, None, :, :] / 255
labels = crop(labels, torch.Size([len(labels), 1, target_shape, target_shape]))
dataset = torch.utils.data.TensorDataset(volumes, labels)
In [17]:
def train():
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True)
    unet = UNet(input_dim, label_dim).to(device)
    unet_opt = torch.optim.Adam(unet.parameters(), lr=lr)
    cur_step = 0

    for epoch in range(n_epochs):
        for real, labels in tqdm(dataloader):
            cur_batch_size = len(real)
            # Move the batch to the device
            real = real.to(device)
            labels = labels.to(device)

            ### Update U-Net ###
            unet_opt.zero_grad()
            pred = unet(real)
            unet_loss = criterion(pred, labels)
            unet_loss.backward()
            unet_opt.step()

            if cur_step % display_step == 0:
                print(f"Epoch {epoch}: Step {cur_step}: U-Net loss: {unet_loss.item()}")
                show_tensor_images(
                    crop(real, torch.Size([len(real), 1, target_shape, target_shape])), 
                    size=(input_dim, target_shape, target_shape)
                )
                show_tensor_images(labels, size=(label_dim, target_shape, target_shape))
                show_tensor_images(torch.sigmoid(pred), size=(label_dim, target_shape, target_shape))
            cur_step += 1

train()
(At each logged step, the cropped input, the label, and the sigmoid of the prediction are also displayed; images omitted here.)

Epoch 0: Step 0: U-Net loss: 0.6669554114341736
Epoch 2: Step 20: U-Net loss: 0.5429989099502563
Epoch 5: Step 40: U-Net loss: 0.5081177353858948
Epoch 7: Step 60: U-Net loss: 0.45750948786735535
Epoch 10: Step 80: U-Net loss: 0.3500884473323822
Epoch 12: Step 100: U-Net loss: 0.3130897879600525
Epoch 15: Step 120: U-Net loss: 0.3473833203315735
Epoch 17: Step 140: U-Net loss: 0.3616056740283966
Epoch 20: Step 160: U-Net loss: 0.36857712268829346
Epoch 22: Step 180: U-Net loss: 0.38156595826148987
Epoch 25: Step 200: U-Net loss: 0.3293497562408447
Epoch 27: Step 220: U-Net loss: 0.3103429973125458
Epoch 30: Step 240: U-Net loss: 0.29405784606933594
Epoch 32: Step 260: U-Net loss: 0.279386967420578
Epoch 35: Step 280: U-Net loss: 0.2760064899921417
Epoch 37: Step 300: U-Net loss: 0.29926183819770813
Epoch 40: Step 320: U-Net loss: 0.2984501123428345
Epoch 42: Step 340: U-Net loss: 0.2688570022583008
Epoch 45: Step 360: U-Net loss: 0.2959635257720947
Epoch 47: Step 380: U-Net loss: 0.28712165355682373
Epoch 50: Step 400: U-Net loss: 0.2800282835960388
Epoch 52: Step 420: U-Net loss: 0.24845512211322784
Epoch 55: Step 440: U-Net loss: 0.2401094287633896
Epoch 57: Step 460: U-Net loss: 0.2998526394367218
Epoch 60: Step 480: U-Net loss: 0.27664265036582947
Epoch 62: Step 500: U-Net loss: 0.27226734161376953
Epoch 65: Step 520: U-Net loss: 0.24700388312339783
Epoch 67: Step 540: U-Net loss: 0.20941729843616486
Epoch 70: Step 560: U-Net loss: 0.21105138957500458
Epoch 72: Step 580: U-Net loss: 0.20534472167491913
Epoch 75: Step 600: U-Net loss: 0.20689469575881958
Epoch 77: Step 620: U-Net loss: 0.19643478095531464
Epoch 80: Step 640: U-Net loss: 0.21871544420719147
Epoch 82: Step 660: U-Net loss: 0.1617809683084488
Epoch 85: Step 680: U-Net loss: 0.15854835510253906
Epoch 87: Step 700: U-Net loss: 0.16356070339679718
Epoch 90: Step 720: U-Net loss: 0.16745175421237946
Epoch 92: Step 740: U-Net loss: 0.13800491392612457
Epoch 95: Step 760: U-Net loss: 0.15077024698257446
Epoch 97: Step 780: U-Net loss: 0.14472010731697083
Epoch 100: Step 800: U-Net loss: 0.15366576611995697
Epoch 102: Step 820: U-Net loss: 0.13391107320785522
Epoch 105: Step 840: U-Net loss: 0.124368816614151
Epoch 107: Step 860: U-Net loss: 0.12332925945520401
Epoch 110: Step 880: U-Net loss: 0.1332693248987198
Epoch 112: Step 900: U-Net loss: 0.12691614031791687
Epoch 115: Step 920: U-Net loss: 0.12035534530878067
Epoch 117: Step 940: U-Net loss: 0.12262485921382904
Epoch 120: Step 960: U-Net loss: 0.11680576205253601
Epoch 122: Step 980: U-Net loss: 0.11665602773427963
Epoch 125: Step 1000: U-Net loss: 0.11054906994104385
Epoch 127: Step 1020: U-Net loss: 0.10503672808408737
Epoch 130: Step 1040: U-Net loss: 0.10351131856441498
Epoch 132: Step 1060: U-Net loss: 0.1042008027434349
Epoch 135: Step 1080: U-Net loss: 0.0992886945605278
Epoch 137: Step 1100: U-Net loss: 0.09839149564504623
Epoch 140: Step 1120: U-Net loss: 0.09309575706720352
Epoch 142: Step 1140: U-Net loss: 0.08533899486064911
Epoch 145: Step 1160: U-Net loss: 0.08815489709377289
Epoch 147: Step 1180: U-Net loss: 0.08567715436220169
Epoch 150: Step 1200: U-Net loss: 0.07848386466503143
Epoch 152: Step 1220: U-Net loss: 0.0819934755563736
Epoch 155: Step 1240: U-Net loss: 0.07530873268842697
Epoch 157: Step 1260: U-Net loss: 0.06388269364833832
Epoch 160: Step 1280: U-Net loss: 0.07382874190807343
Epoch 162: Step 1300: U-Net loss: 0.061365727335214615
Epoch 165: Step 1320: U-Net loss: 0.05920993164181709
Epoch 167: Step 1340: U-Net loss: 0.06318499147891998
Epoch 170: Step 1360: U-Net loss: 0.06700340658426285
Epoch 172: Step 1380: U-Net loss: 0.06866621971130371
Epoch 175: Step 1400: U-Net loss: 0.05597443878650665
Epoch 177: Step 1420: U-Net loss: 0.05495245382189751
Epoch 180: Step 1440: U-Net loss: 0.054204247891902924
Epoch 182: Step 1460: U-Net loss: 0.04995230585336685
Epoch 185: Step 1480: U-Net loss: 0.05663346126675606
Epoch 187: Step 1500: U-Net loss: 0.05510256439447403
Epoch 190: Step 1520: U-Net loss: 0.04957708716392517
Epoch 192: Step 1540: U-Net loss: 0.05308596044778824
Epoch 195: Step 1560: U-Net loss: 0.05588943511247635
Epoch 197: Step 1580: U-Net loss: 0.04845543950796127